{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1> 2c. Refactoring to add batching and feature-creation </h1>\n",
    "\n",
    "In this notebook, we continue reading the same small dataset, but refactor our ML pipeline in two small, but significant, ways:\n",
    "<ol>\n",
    "<li> Refactor the input to read data in batches.\n",
    "<li> Refactor the feature creation so that it is not one-to-one with inputs.\n",
    "</ol>\n",
    "The Pandas input function in the previous notebook also batched, but only after it had read the entire dataset into memory. On a large dataset, that won't be an option."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/envs/py3env/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import shutil\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2> 1. Refactor the input </h2>\n",
    "\n",
    "Read the data created in Lab1a, but this time make the input function more general and performant. Instead of using Pandas, we will use TensorFlow's Dataset API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "CSV_COLUMNS = ['fare_amount', 'pickuplon','pickuplat','dropofflon','dropofflat','passengers', 'key']\n",
    "LABEL_COLUMN = 'fare_amount'\n",
    "DEFAULTS = [[0.0], [-74.0], [40.0], [-74.0], [40.7], [1.0], ['nokey']]\n",
    "\n",
    "def read_dataset(filename, mode, batch_size = 512):\n",
    "  def _input_fn():\n",
    "    def decode_csv(value_column):\n",
    "      # Parse one CSV line into a dict of features and a label\n",
    "      columns = tf.decode_csv(value_column, record_defaults = DEFAULTS)\n",
    "      features = dict(zip(CSV_COLUMNS, columns))\n",
    "      label = features.pop(LABEL_COLUMN)\n",
    "      return features, label\n",
    "\n",
    "    # Create list of files that match pattern\n",
    "    file_list = tf.gfile.Glob(filename)\n",
    "\n",
    "    # Create dataset from file list, one element per line of text\n",
    "    dataset = tf.data.TextLineDataset(file_list).map(decode_csv)\n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "      num_epochs = None # repeat indefinitely\n",
    "      dataset = dataset.shuffle(buffer_size = 10 * batch_size)\n",
    "    else:\n",
    "      num_epochs = 1 # end-of-input after one pass\n",
    "\n",
    "    # Batches are assembled as lines are read, so the whole file never has to fit in memory\n",
    "    dataset = dataset.repeat(num_epochs).batch(batch_size)\n",
    "    return dataset.make_one_shot_iterator().get_next()\n",
    "  return _input_fn\n",
    "\n",
    "\n",
    "def get_train():\n",
    "  return read_dataset('./taxi-train.csv', mode = tf.estimator.ModeKeys.TRAIN)\n",
    "\n",
    "def get_valid():\n",
    "  return read_dataset('./taxi-valid.csv', mode = tf.estimator.ModeKeys.EVAL)\n",
    "\n",
    "def get_test():\n",
    "  return read_dataset('./taxi-test.csv', mode = tf.estimator.ModeKeys.EVAL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2> 2. Refactor the way features are created. </h2>\n",
    "\n",
    "For now, the input columns are passed through unchanged (the same as in the previous lab). However, routing them through a separate function will let us break the one-to-one relationship between inputs and features later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_COLUMNS = [\n",
    "    tf.feature_column.numeric_column('pickuplon'),\n",
    "    tf.feature_column.numeric_column('pickuplat'),\n",
    "    tf.feature_column.numeric_column('dropofflat'),\n",
    "    tf.feature_column.numeric_column('dropofflon'),\n",
    "    tf.feature_column.numeric_column('passengers'),\n",
    "]\n",
    "\n",
    "def add_more_features(feats):\n",
    "  # Nothing to add (yet!)\n",
    "  return feats\n",
    "\n",
    "feature_cols = add_more_features(INPUT_COLUMNS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2> Create and train the model </h2>\n",
    "\n",
    "Note that training runs for `steps` batches, so the model processes `steps * batch_size` examples (here, 100 * 512 = 51,200)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "INFO:tensorflow:Using config: {'_evaluation_master': '', '_global_id_in_cluster': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f1684957828>, '_num_ps_replicas': 0, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_num_worker_replicas': 1, '_model_dir': 'taxi_trained', '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_is_chief': True, '_train_distribute': None, '_task_id': 0, '_tf_random_seed': None, '_master': '', '_keep_checkpoint_every_n_hours': 10000, '_task_type': 'worker', '_service': None, '_session_config': None, '_keep_checkpoint_max': 5}\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into taxi_trained/model.ckpt.\n",
      "INFO:tensorflow:step = 1, loss = 104745.32\n",
      "INFO:tensorflow:Saving checkpoints for 100 into taxi_trained/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 66315.17.\n"
     ]
    }
   ],
   "source": [
    "tf.logging.set_verbosity(tf.logging.INFO)\n",
    "OUTDIR = 'taxi_trained'\n",
    "shutil.rmtree(OUTDIR, ignore_errors = True) # start fresh each time\n",
    "model = tf.estimator.LinearRegressor(\n",
    "      feature_columns = feature_cols, model_dir = OUTDIR)\n",
    "model.train(input_fn = get_train(), steps = 100);  # semicolon suppresses echoing the returned estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3> Evaluate model </h3>\n",
    "\n",
    "As before, evaluate on the validation data.  We'll do the third refactoring (to move the evaluation into the training loop) in the next lab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Starting evaluation at 2019-01-08-03:38:41\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from taxi_trained/model.ckpt-100\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Evaluation [1/1]\n",
      "INFO:tensorflow:Finished evaluation at 2019-01-08-03:38:42\n",
      "INFO:tensorflow:Saving dict for global step 100: average_loss = 108.89749, global_step = 100, loss = 55755.516\n",
      "RMSE on validation dataset = 10.435396194458008\n"
     ]
    }
   ],
   "source": [
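    "def print_rmse(model, name, input_fn):\n",
    "  # steps = 1 evaluates a single batch (at most batch_size examples), not the whole file\n",
    "  metrics = model.evaluate(input_fn = input_fn, steps = 1)\n",
    "  # average_loss is the mean squared error, so its square root is the RMSE\n",
    "  print('RMSE on {} dataset = {}'.format(name, np.sqrt(metrics['average_loss'])))\n",
    "\n",
    "print_rmse(model, 'validation', get_valid())"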
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2017 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
